package entx

import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"entgo.io/ent/entc"
	"entgo.io/ent/entc/gen"
	"golang.org/x/mod/modfile"
	"golang.org/x/tools/imports"

	"github.com/seal-io/walrus/utils/files"
	"github.com/seal-io/walrus/utils/strs"
)

// init runs this package's fix-up helpers for the entc/gen defaults
// (defined elsewhere in the package) exactly once, in a fixed order:
// templates first, then template funcs, then ruleset acronyms.
func init() {
	fixups := []func(){
		fixDefaultTemplates,
		fixDefaultTemplateFuncs,
		fixDefaultTemplateRulesetAcronyms,
	}

	for _, apply := range fixups {
		apply()
	}
}

// Hook aliases gen.Hook, letting callers of this package name hooks without
// referring to entc/gen directly.
type Hook = gen.Hook

// Template aliases gen.Template, letting callers of this package name
// templates without referring to entc/gen directly.
type Template = gen.Template

// Config describes where the ent schema lives inside a Go project and how
// code should be generated from it.
type Config struct {
	// ProjectDir is the filesystem path of the Go project's root directory.
	ProjectDir string

	// Project is the Go module path of the project; when blank, it is read
	// from the go.mod file located in ProjectDir.
	Project string

	// Package is the Go import path that parents both the API schema and the
	// generated model. It defaults to "<project>/dao", meaning the schema is
	// expected under "<project>/dao/schema" and the generated code is written
	// to "<project>/dao/model".
	Package string

	// Header is an optional signature written at the top of every generated
	// file, e.g. '// Code generated by "walrus", DO NOT EDIT.'. When blank,
	// validateAndDefault fills in a walrus "Code generated ... DO NOT EDIT."
	// line.
	Header string

	// Templates lists extra templates to execute, or templates that override
	// the defaults; this package's own template is always appended after
	// them. When nil, only the package's template is used.
	//
	// Additional templates are executed on the Graph object, and each output
	// is stored in a file whose name is derived from the template name.
	Templates []*Template

	// Hooks lists optional hooks applied to the graph before/after code
	// generation; they are placed ahead of this package's built-in hooks.
	Hooks []Hook
}

// validateAndDefault checks the required fields of c and fills in defaults
// for the optional ones, modifying c in place.
//
// Defaults:
//   - Project: the module path read from "<ProjectDir>/go.mod".
//   - Package: "<Project>/dao".
//   - Header:  `// Code generated by "walrus", DO NOT EDIT.`
//
// As a side effect, it sets the package-level imports.LocalPrefix to the
// project path, so imports.Process groups the project's own imports
// separately.
//
// It returns an error if ProjectDir is blank, if the project path cannot be
// read from go.mod, or if Package is not Project itself or a sub-path of it.
func (c *Config) validateAndDefault() error {
	if c.ProjectDir == "" {
		return errors.New("invalid config: project dir is blank")
	}

	if c.Project == "" {
		project, err := getProject(c.ProjectDir)
		if err != nil {
			return fmt.Errorf("invalid config: error getting project: %w", err)
		}
		c.Project = project
	}
	// NB: this mutates global state of golang.org/x/tools/imports.
	imports.LocalPrefix = c.Project

	if c.Package == "" {
		c.Package = strs.Join("/", c.Project, "dao")
	}

	// Require a path-segment boundary: a bare prefix check would wrongly
	// accept "example.com/foobar/dao" for project "example.com/foo".
	if c.Package != c.Project && !strings.HasPrefix(c.Package, c.Project+"/") {
		return errors.New("invalid config: package must be under project")
	}

	if c.Header == "" {
		c.Header = `// Code generated by "walrus", DO NOT EDIT.`
	}

	return nil
}

// Generate loads the ent schema from "<Package>/schema" under
// cfg.ProjectDir and runs the ent code generator into a temporary working
// directory. It then swaps the output into "<Package>/model", replacing
// whatever was generated there before.
//
// If installing the new output fails, the previous model directory is moved
// back. On error the temporary working directory is left in place,
// presumably so the partial output can be inspected. On success it is
// removed, together with the previous model files.
func Generate(cfg Config) (err error) {
	// Validate before touching the filesystem. The error path below keeps
	// the working dir, so validating later would leak an empty temp dir for
	// every invalid config.
	if err = cfg.validateAndDefault(); err != nil {
		return err
	}

	// Prepare working dir.
	generatedDir := files.TempDir("entio-generated-*")
	defer func() {
		if err != nil {
			return
		}
		_ = os.RemoveAll(generatedDir)
	}()
	newGeneratedDir := filepath.Join(generatedDir, "new")
	oldGeneratedDir := filepath.Join(generatedDir, "old")

	baseDir := filepath.Join(cfg.ProjectDir, strings.TrimPrefix(cfg.Package, cfg.Project))
	schemaDir := filepath.Join(baseDir, "schema")
	modelDir := filepath.Join(baseDir, "model")
	schemaPkg := strs.Join("/", cfg.Package, "schema")
	modelPkg := strs.Join("/", cfg.Package, "model")

	// Build fresh slices instead of appending to cfg.Templates/cfg.Hooks.
	// cfg is a value copy, but its slices still share the caller's backing
	// arrays, so append could overwrite the caller's spare capacity.
	templates := make([]*Template, 0, len(cfg.Templates)+1)
	templates = append(templates, cfg.Templates...)
	templates = append(templates, loadTemplate())

	builtinHooks := loadHooks()
	hooks := make([]Hook, 0, len(cfg.Hooks)+len(builtinHooks))
	hooks = append(hooks, cfg.Hooks...)
	hooks = append(hooks, builtinHooks...)

	// Create configuration.
	c := &gen.Config{
		Features:  loadFeatures(),
		Storage:   loadStorage(),
		Templates: templates,
		Hooks:     hooks,
		Header:    cfg.Header,
		Target:    newGeneratedDir,
		Schema:    schemaPkg,
		Package:   modelPkg,
	}

	// Load generation graph.
	g, err := entc.LoadGraph(schemaDir, c)
	if err != nil {
		return fmt.Errorf("error loading graph: %w", err)
	}

	// Generate.
	if err = g.Gen(); err != nil {
		return fmt.Errorf("error generating: %w", err)
	}

	// Move the currently generated files aside. A missing model dir only
	// means nothing was generated before. Check the error value rather than
	// its message text, which differs across operating systems.
	if err = os.Rename(modelDir, oldGeneratedDir); err != nil {
		if !errors.Is(err, os.ErrNotExist) {
			return fmt.Errorf("error cleaning stale generated files: %w", err)
		}
		err = nil
	}

	// Restore the previous generated files if installing the new ones fails.
	defer func() {
		if err != nil {
			_ = os.Rename(oldGeneratedDir, modelDir)
		}
	}()

	// NOTE(review): os.Rename cannot move directories across filesystems.
	// If files.TempDir creates its directory on a different device than
	// ProjectDir, this will fail. Confirm where files.TempDir places it.
	if err = os.Rename(newGeneratedDir, modelDir); err != nil {
		return fmt.Errorf("error moving new generated files to %s: %w", modelDir, err)
	}

	return nil
}

// loadFeatures returns the optional ent code-generation features enabled for
// every Generate run, in a fixed order.
func loadFeatures() []gen.Feature {
	enabled := make([]gen.Feature, 0, 8)
	enabled = append(enabled,
		gen.FeatureIntercept,
		gen.FeatureSnapshot,
		gen.FeatureSchemaConfig,
		gen.FeatureLock,
		gen.FeatureModifier,
		gen.FeatureExecQuery,
		gen.FeatureUpsert,
		gen.FeatureVersionedMigration,
	)

	return enabled
}

// loadStorage returns the "sql" storage driver for ent code generation.
// It panics if gen.NewStorage rejects the driver name; that can only happen
// through a programming error, since the name is a constant.
func loadStorage() *gen.Storage {
	storage, err := gen.NewStorage("sql")
	if err == nil {
		return storage
	}

	panic(fmt.Errorf("error creating storage: %w", err))
}

// getProject returns the module path declared by the go.mod file inside
// projectDir.
//
// It returns an error if go.mod cannot be read or parsed, or if it declares
// no module path (modfile.Parse leaves File.Module nil when the module
// directive is absent).
func getProject(projectDir string) (string, error) {
	mfn := filepath.Join(projectDir, "go.mod")

	mfb, err := os.ReadFile(mfn)
	if err != nil {
		return "", fmt.Errorf("error reading the go.mod: %w", err)
	}

	mf, err := modfile.Parse(mfn, mfb, nil)
	if err != nil {
		return "", fmt.Errorf("error parsing the go.mod: %w", err)
	}

	// Without this guard, a go.mod lacking a module directive would cause a
	// nil-pointer panic on mf.Module below.
	if mf.Module == nil || mf.Module.Mod.Path == "" {
		return "", fmt.Errorf("error parsing the go.mod: no module path declared in %q", mfn)
	}

	return mf.Module.Mod.Path, nil
}
